{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from utils import log_progress\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(12500, 12500)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "files = glob.glob('train/*')\n",
    "\n",
    "cat_files = [fn for fn in files if 'cat' in fn]\n",
    "dog_files = [fn for fn in files if 'dog' in fn]\n",
    "len(cat_files), len(dog_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cat datasets: (1500,) (500,) (500,)\n",
      "Dog datasets: (1500,) (500,) (500,)\n"
     ]
    }
   ],
   "source": [
    "cat_train = np.random.choice(cat_files, size=1500, replace=False)\n",
    "dog_train = np.random.choice(dog_files, size=1500, replace=False)\n",
    "cat_files = list(set(cat_files) - set(cat_train))\n",
    "dog_files = list(set(dog_files) - set(dog_train))\n",
    "\n",
    "cat_val = np.random.choice(cat_files, size=500, replace=False)\n",
    "dog_val = np.random.choice(dog_files, size=500, replace=False)\n",
    "cat_files = list(set(cat_files) - set(cat_val))\n",
    "dog_files = list(set(dog_files) - set(dog_val))\n",
    "\n",
    "cat_test = np.random.choice(cat_files, size=500, replace=False)\n",
    "dog_test = np.random.choice(dog_files, size=500, replace=False)\n",
    "\n",
    "print('Cat datasets:', cat_train.shape, cat_val.shape, cat_test.shape)\n",
    "print('Dog datasets:', dog_train.shape, dog_val.shape, dog_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dir = 'training_data'\n",
    "val_dir = 'validation_data'\n",
    "test_dir = 'test_data'\n",
    "\n",
    "train_files = np.concatenate([cat_train, dog_train])\n",
    "validate_files = np.concatenate([cat_val, dog_val])\n",
    "test_files = np.concatenate([cat_test, dog_test])\n",
    "\n",
    "os.mkdir(train_dir) if not os.path.isdir(train_dir) else None\n",
    "os.mkdir(val_dir) if not os.path.isdir(val_dir) else None\n",
    "os.mkdir(test_dir) if not os.path.isdir(test_dir) else None\n",
    "\n",
    "for fn in log_progress(train_files, name='Training Images'):\n",
    "    shutil.copy(fn, train_dir)\n",
    "\n",
    "for fn in log_progress(validate_files, name='Validation Images'):\n",
    "    shutil.copy(fn, val_dir)\n",
    "    \n",
    "for fn in log_progress(test_files, name='Test Images'):\n",
    "    shutil.copy(fn, test_dir)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "widgets": {
   "state": {
    "707cda043c794c6bbccbf3e3f0264e09": {
     "views": [
      {
       "cell_index": 5
      }
     ]
    },
    "76235700d807469ea94675fe7f2b6187": {
     "views": [
      {
       "cell_index": 5
      }
     ]
    },
    "7aadcb865964421fb6f976525e21adca": {
     "views": [
      {
       "cell_index": 5
      }
     ]
    }
   },
   "version": "1.2.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
